package nsalt.mem

import chisel3._
import chisel3.util._

import nsalt._
import nsalt.bus._
import nsalt.util._

// Ref:
// https://oscpu.github.io/NutShell-doc/%E5%8A%9F%E8%83%BD%E9%83%A8%E4%BB%B6/tlb.html


// Sv39 page-based virtual memory. For details, see:
// RISC-V SPEC (V1.7) Volume II: RISC-V Privileged Architectures Section 4.6 (P46)

/** Sv39 virtual-memory layout: field widths and bundle shapes for virtual
  * addresses, physical addresses, page-table entries and the satp register.
  *
  * Widths that depend on the implemented physical address size are derived
  * from `PHYS_ADDR_LEN` and `XLEN`, which come from `Constants`.
  *
  * NOTE: trait vals are initialized in declaration order and several of them
  * are derived from earlier ones, so keep the order below intact.
  */
trait VirtMemSchema extends Constants {
  // Number of page-table levels (VPN[2], VPN[1], VPN[0]).
  val Level = 3
  // Byte offset inside a 4 KiB page.
  val pageOffsetLen  = 12

  val physPageNum0Len = 9
  val physPageNum1Len = 9
  // Remaining physical-address bits above PPN[1] (2 bits when PHYS_ADDR_LEN == 32).
  val physPageNum2Len = PHYS_ADDR_LEN - pageOffsetLen - physPageNum0Len - physPageNum1Len
  val physPageNumLen = physPageNum2Len + physPageNum1Len + physPageNum0Len

  val virtPageNum2Len = 9
  val virtPageNum1Len = 9
  val virtPageNum0Len = 9
  val virtPageNumLen = virtPageNum2Len + virtPageNum1Len + virtPageNum0Len

  // SPEC: 4.1.12 Supervisor Address Translation and Protection (satp) Register
  val satpLen = XLEN
  val satpModeLen = 4

  // SPEC: 4.1.10 Supervisor Address Space ID Register (sasid)
  val asidLen = 16

  // PTE flag bits, most significant first (see Figure 4.18, Sv39 page table entry):
  //   D / A / G / U / X / W / R / V
  val flagLen = 8

  // PTE: Page Table Entry
  val pageTableEntryLen = XLEN
  // satp bits not covered by MODE, ASID or the implemented PPN bits.
  val satpRemLen = XLEN - physPageNumLen - satpModeLen - asidLen

  // PTE bits above the implemented PPN field (the literal 2 is the RSW field).
  val pteRemLen  = XLEN - physPageNumLen - 2 - flagLen


  /** Virtual page number split into its three Sv39 level indices (VPN[2] is the MSB part). */
  def virtPageNumBundle = new Bundle {
    val virtPageNum2 = UInt(virtPageNum2Len.W)
    val virtPageNum1 = UInt(virtPageNum1Len.W)
    val virtPageNum0 = UInt(virtPageNum0Len.W)
  }

  /** Virtual address viewed as VPN[2] | VPN[1] | VPN[0] | page offset. */
  def vaBundle = new Bundle {
    val virtPageNum2 = UInt(virtPageNum2Len.W)
    val virtPageNum1 = UInt(virtPageNum1Len.W)
    val virtPageNum0 = UInt(virtPageNum0Len.W)
    val off  = UInt( pageOffsetLen.W)
  }

  /** Virtual address viewed as a flat VPN | page offset. */
  def vaBundle2 = new Bundle {
    val virtPageNum  = UInt(virtPageNumLen.W)
    val off  = UInt(pageOffsetLen.W)
  }

  /** Physical address viewed as PPN[2] | PPN[1] | PPN[0] | page offset. */
  def paBundle = new Bundle {
    val physPageNum2 = UInt(physPageNum2Len.W)
    val physPageNum1 = UInt(physPageNum1Len.W)
    val physPageNum0 = UInt(physPageNum0Len.W)
    val off  = UInt( pageOffsetLen.W)
  }

  /** Physical address viewed as a flat PPN | page offset. */
  def paBundle2 = new Bundle {
    val physPageNum  = UInt(physPageNumLen.W)
    val off  = UInt(pageOffsetLen.W)
  }

  /** Builds a page-table-entry address as `physPageNum | virtPageNumn | 000`.
    *
    * `physPageNum` selects the page-table page and `virtPageNumn` (one VPN[i]
    * slice) indexes into it; the three low zero bits scale that index by
    * 8 bytes, the size of one 64-bit PTE.
    */
  def physAddrApply(physPageNum: UInt, virtPageNumn: UInt): UInt = {
    Cat(Cat(physPageNum, virtPageNumn), 0.U(3.W))
  }

  /** Page-table entry layout, most significant field first. */
  def pteBundle = new Bundle {
    val reserved  = UInt(pteRemLen.W)
    val physPageNum  = UInt(physPageNumLen.W)
    val rsw  = UInt(2.W)
    val flag = new Bundle {
      val d    = UInt(1.W)
      val a    = UInt(1.W)
      val g    = UInt(1.W)
      val u    = UInt(1.W)
      val x    = UInt(1.W)
      val w    = UInt(1.W)
      val r    = UInt(1.W)
      val v    = UInt(1.W)
    }
  }

  /** satp register layout: MODE | ASID | unused | PPN. */
  def satpBundle = new Bundle {
    val mode = UInt(satpModeLen.W)
    val asid = UInt(asidLen.W)
    val res = UInt(satpRemLen.W)
    val physPageNum  = UInt(physPageNumLen.W)
  }


  /** The 8 PTE flag bits as individual Bools, in the same order as `pteBundle.flag`. */
  def flagBundle = new Bundle  {
    val d    = Bool()
    val a    = Bool()
    val g    = Bool()
    val u    = Bool()
    val x    = Bool()
    val w    = Bool()
    val r    = Bool()
    val v    = Bool()
  }

  /** Combines a translated PPN with `vaddr` under a superpage `mask`.
    *
    * The full mask always covers PPN[2] (all ones) and the page offset
    * (all zeros); `mask` covers the PPN[1]/PPN[0] bits in between.
    * NOTE(review): assumes `MaskData(old, new, fullMask)` takes `new` where
    * `fullMask` is 1 and `old` elsewhere — confirm in nsalt.util.
    */
  def maskPaddr(physPageNum:UInt, vaddr:UInt, mask:UInt) = {
    MaskData(vaddr, Cat(physPageNum, 0.U(pageOffsetLen.W)), Cat(Fill(physPageNum2Len, 1.U(1.W)), mask, 0.U(pageOffsetLen.W)))
  }

  /** Compares `pattern` and `virtPageNum` under a superpage mask.
    *
    * VPN[2] bits are always compared; the lower VPN bits are compared only
    * where `mask` is 1.
    */
  def MaskEQ(mask: UInt, pattern: UInt, virtPageNum: UInt): Bool = {
    // All-ones prefix sized from virtPageNum2Len instead of a hard-coded 9-bit literal.
    val fullMask = Cat(Fill(virtPageNum2Len, 1.U(1.W)), mask)
    (fullMask & pattern) === (fullMask & virtPageNum)
  }

}


/** Static configuration of one TLB instance.
  *
  * @param NAME        instance name (exposed as `TLB_NAME`)
  * @param USER_BITS   width of the user field passed as `userBits` to the TLB's bus ports
  * @param TOTAL_ENTRY total number of TLB entries; must be a positive multiple of `WAYS`
  * @param WAYS        associativity; `TOTAL_ENTRY / WAYS` is the number of sets
  * @throws IllegalArgumentException if the geometry is inconsistent
  */
sealed case class TransLookasideConf (
  NAME: String = "tlb",
  USER_BITS: Int = 0,

  TOTAL_ENTRY: Int = 4,
  WAYS: Int = 4
) {
  require(USER_BITS >= 0, s"USER_BITS must be non-negative, got $USER_BITS")
  // WAYS is checked first so the modulo below cannot divide by zero.
  require(WAYS > 0, s"WAYS must be positive, got $WAYS")
  require(TOTAL_ENTRY > 0, s"TOTAL_ENTRY must be positive, got $TOTAL_ENTRY")
  // The set count is computed with integer division; reject configs that would silently truncate.
  require(TOTAL_ENTRY % WAYS == 0, s"TOTAL_ENTRY ($TOTAL_ENTRY) must be a multiple of WAYS ($WAYS)")
}


/** TLB geometry and entry layout, derived from the implicit [[TransLookasideConf]]
  * and the Sv39 widths in [[VirtMemSchema]].
  *
  * NOTE: trait vals are initialized in declaration order and later ones depend on
  * earlier ones (e.g. TagBits <- IndexBits <- Sets), so do not reorder them.
  */
trait TransLookasideConst extends VirtMemSchema{
  
  // Configuration supplying the name, user bits and entry/way counts.
  implicit val conf: TransLookasideConf

  // Abstract declarations of widths this trait relies on; concrete values must be
  // supplied by a mixed-in definition (presumably Constants — TODO confirm).
  val ADDR_BITS: Int
  val PHYS_ADDR_LEN: Int
  val VIRT_MEM_ADDR_LEN: Int
  val XLEN: Int

  val TLB_NAME = conf.NAME
  val USER_BITS = conf.USER_BITS

  val maskLen = virtPageNum0Len + virtPageNum1Len  // 18: one mask bit per VPN[1]/VPN[0] bit (superpage support)
  val metaLen = virtPageNumLen + asidLen + maskLen + flagLen // 27 + 16 + 18 + 8 = 69; TODO: is asid necessary?
  val dataLen = physPageNumLen + PHYS_ADDR_LEN // PPN + address of the PTE it came from
  val tlbLen  = metaLen + dataLen

  val WAYS = conf.WAYS
  val TOTAL_ENTRY = conf.TOTAL_ENTRY
  // Integer division: TOTAL_ENTRY is expected to be a multiple of WAYS.
  val Sets = TOTAL_ENTRY / WAYS
  // log2Up returns at least 1, so Sets == 1 (fully associative) still yields a 1-bit index.
  val IndexBits = log2Up(Sets)
  val TagBits = virtPageNumLen - IndexBits

  // Debug switch; hard-wired off.
  val debug = false //&& tlbname == "dtlb"

  /** Virtual address split as tag | set index | page offset (most significant first). */
  def vaddrTlbBundle = new Bundle {
    val tag = UInt(TagBits.W)
    val index = UInt(IndexBits.W)
    val off = UInt(pageOffsetLen.W)
  }

  /** Lookup-side half of a TLB entry: VPN, ASID, superpage mask and PTE flags. */
  def metaBundle = new Bundle {
    val virtPageNum = UInt(virtPageNumLen.W)
    val asid = UInt(asidLen.W)
    val mask = UInt(maskLen.W) // to support super page
    val flag = UInt(flagLen.W)
  }

  /** Translation-side half of a TLB entry: PPN plus the PTE's physical address. */
  def dataBundle = new Bundle {
    val physPageNum = UInt(physPageNumLen.W)
    val entryAddr = UInt(PHYS_ADDR_LEN.W) // pte addr, used to write back pte when flag changes (flag.d, flag.v)
  }

  /** A full TLB entry with meta and data fields flattened together. */
  def tlbBundle = new Bundle {
    val virtPageNum = UInt(virtPageNumLen.W)
    val asid = UInt(asidLen.W)
    val mask = UInt(maskLen.W)
    val flag = UInt(flagLen.W)

    val physPageNum = UInt(physPageNumLen.W)
    val entryAddr = UInt(PHYS_ADDR_LEN.W)
  }

  /** A full TLB entry as two packed words: meta (metaLen bits) and data (dataLen bits). */
  def tlbBundle2 = new Bundle {
    val meta = UInt(metaLen.W)
    val data = UInt(dataLen.W)
  }

  /** Extracts the set-index field of `vaddr` using the vaddrTlbBundle layout. */
  def getIndex(vaddr: UInt) : UInt = {
    vaddr.asTypeOf(vaddrTlbBundle).index
  }
}

/** I/O bundle of the TLB.
  *
  * Field declaration order determines the Bundle element order, so keep it stable.
  */
class TransLookasidePort(implicit val conf: TransLookasideConf)  extends Bundle with TransLookasideConst{
  // Incoming requests addressed with VIRT_MEM_ADDR_LEN-bit (virtual) addresses; flipped, so the TLB is the responder.
  val in = Flipped(new BusUncached(userBits = USER_BITS, addrBits = VIRT_MEM_ADDR_LEN))
  // Outgoing requests using BusUncached's default address width (presumably physical — TODO confirm).
  val out = new BusUncached(userBits = USER_BITS)

  // Separate memory port; presumably used for page-table walks / PTE write-back — TODO confirm.
  val mem = new BusUncached()
  // Flush request input.
  val flush = Input(Bool())
  // MMU-related CSR interface (MemManagePort is defined elsewhere).
  val csrMMU = new MemManagePort
  // Input flag; semantics defined by the TLB implementation, not visible here.
  val cacheEmpty = Input(Bool())
  // Output flag; presumably "instruction page fault" — confirm against the TLB implementation.
  val ipf = Output(Bool())
}
